# Convolution + batch normalization (no activation).
# Builds a layer function with Sequential: a ConvolutionalLayer ("heNormal" init,
# pad = true, no bias) whose output feeds a BatchNormalizationLayer.
#   outChannels - first argument of ConvolutionalLayer
#   kernel      - second argument of ConvolutionalLayer; the ResNet blocks in this file pass (3:3)
#   stride      - forwarded as the convolution stride
#   bnTimeConst - forwarded as normalizationTimeConstant of the batch normalization
# NOTE(review): bias is presumably disabled because batch normalization follows - confirm.
ConvBNLayer{outChannels, kernel, stride, bnTimeConst} = Sequential(
    ConvolutionalLayer{outChannels, kernel, init = "heNormal", stride = stride, pad = true, bias = false} :
    BatchNormalizationLayer{spatialRank = 2, normalizationTimeConstant = bnTimeConst, useCntkEngine = true}
)

# Convolution with kernel size 1 x 1 + batch normalization (no activation).
# Needs special treatment in order to explicitly disable padding (there seems to be a bug in CNTK where
# an assert in ConvolveGeometry.h fails incorrectly for pad == true, input size 32 x 32, kernel size 1 x 1,
# and stride 2).
# Apart from the fixed (1:1) kernel and pad = false, the composition is the same as ConvBNLayer.
#   outChannels - first argument of ConvolutionalLayer
#   stride      - forwarded as the convolution stride
#   bnTimeConst - forwarded as normalizationTimeConstant of the batch normalization
Conv1x1BNLayer{outChannels, stride, bnTimeConst} = Sequential(
    ConvolutionalLayer{outChannels, (1:1), init = "heNormal", stride = stride, pad = false, bias = false} :
    BatchNormalizationLayer{spatialRank = 2, normalizationTimeConstant = bnTimeConst, useCntkEngine = true}
)

# Convolution + batch normalization + rectified linear unit.
# Chains ConvBNLayer (all four arguments are forwarded unchanged, kernelSize
# becoming its kernel argument) with ReLU.
#   outChannels - forwarded to ConvBNLayer
#   kernelSize  - forwarded to ConvBNLayer as the kernel
#   stride      - forwarded to ConvBNLayer
#   bnTimeConst - forwarded to ConvBNLayer
ConvBNReLULayer{outChannels, kernelSize, stride, bnTimeConst} = Sequential(
    ConvBNLayer{outChannels, kernelSize, stride, bnTimeConst} :
    ReLU
)

# The basic ResNet block contains two 3 x 3 convolutions whose output is added to the original
# input of the block; a ReLU is applied to the sum.
#   outChannels - channel count passed to both convolution layers
#   bnTimeConst - batch normalization time constant passed to both convolution layers
# NOTE(review): Plus(b, x) presumably requires the input x to already have outChannels
# channels, since no projection is applied on the shortcut - confirm.
ResNetBasic{outChannels, bnTimeConst} = {
    apply(x) = {
        # Residual branch: conv + BN + ReLU, then conv + BN, both with stride (1:1).
        b = Sequential(
            ConvBNReLULayer{outChannels, (3:3), (1:1), bnTimeConst} :
            ConvBNLayer{outChannels, (3:3), (1:1), bnTimeConst})(x)

        # Identity shortcut: add the unmodified input, then apply the final activation.
        p = Plus(b, x)
        r = ReLU(p)
    }.r
}.apply

# A block to reduce the feature map resolution. Two 3 x 3 convolutions, the first one strided,
# whose output is added to the original input after it has passed through a 1 x 1 convolution
# with the same stride; a ReLU is applied to the sum.
#   outChannels - channel count passed to all three convolution layers
#   stride      - stride of the first 3 x 3 convolution and of the 1 x 1 shortcut convolution
#   bnTimeConst - batch normalization time constant passed to all three convolution layers
ResNetBasicInc{outChannels, stride, bnTimeConst} = {
    apply(x) = {
        # Residual branch: strided conv + BN + ReLU, then conv + BN with stride (1:1).
        b = Sequential(
            ConvBNReLULayer{outChannels, (3:3), stride, bnTimeConst} :
            ConvBNLayer{outChannels, (3:3), (1:1), bnTimeConst})(x)

        # Shortcut: 1 x 1 convolution + BN using the same stride as the residual branch.
        s = Conv1x1BNLayer{outChannels, stride, bnTimeConst}(x)

        # Sum both branches, then apply the final activation.
        p = Plus(b, s)
        r = ReLU(p)
    }.r
}.apply

# Convolution layer with 1x1 kernel whose weights are initialized using
# method of He et al. (https://arxiv.org/abs/1502.01852).
# Returns the Convolution node (member c). No bias and no activation are applied.
ConvLayer1x1MSRAInit(
    input,              # input node
    inputChannels,      # number of input channels
    outputChannels,     # number of output channels
    horizontalStride,   # horizontal stride
    verticalStride      # vertical stride
    ) = [
    # The kernel spans a single pixel and all input channels.
    kernelShape = (1:1:inputChannels)
    outputShape = (1:1:outputChannels)
    # The third stride component is set to the input channel count.
    strideShape = (horizontalStride:verticalStride:inputChannels)
    # Weights are an outputChannels x inputChannels parameter with 'heNormal' initialization.
    W = Parameter(outputChannels, inputChannels, init ='heNormal')
    # Padding is explicitly disabled.
    c = Convolution(W, input, kernelShape, mapDims = outputShape, stride = strideShape,
                    autoPadding = false, imageLayout = "cudnn")
].c

# Learnable upsampling layer initialized with bilinear weights.
# Thin wrapper around ConvolutionTransposeLayer using a square kernel, 'bilinear'
# initialization with initValueScale = 1.0, and no bias.
LearnableUpsamplingLayer{
    inputChannels,      # number of input channels
    outputChannels,     # number of output channels
    kernelSize,         # kernel size (both horizontal and vertical)
    stride              # stride (both horizontal and vertical)
    } = ConvolutionTransposeLayer{
            outputChannels,
            (kernelSize:kernelSize),
            inputChannels,
            stride = stride,
            init = 'bilinear',
            bias = false,
            initValueScale = 1.0}

# Crop node with automatically computed offsets based on a least common ancestor of input nodes.
# Creates a 'Crop' computation node with two inputs, in this order:
#   inputShape  - first input of the Crop node
#   outputShape - second input of the Crop node
#   tag         - optional, defaults to ''; not referenced explicitly in the body
# NOTE(review): presumably inputShape is the node being cropped and outputShape the node
# whose shape is matched, and tag reaches the node implicitly with the function
# arguments (see the inline comment) - confirm.
CropAutomatic(inputShape, outputShape, tag = '') = new ComputationNode [
    operation = 'Crop';
    inputs = (inputShape:outputShape) /*plus the function args*/
]

# Crop node with automatically computed offsets based on specified ancestors of input nodes.
# Creates a 'Crop' computation node with four inputs, in this order:
#   inputShape     - first input of the Crop node
#   outputShape    - second input of the Crop node
#   inputAncestor  - third input; ancestor specified for inputShape
#   outputAncestor - fourth input; ancestor specified for outputShape
#   tag            - optional, defaults to ''; not referenced explicitly in the body
# NOTE(review): presumably tag reaches the node implicitly with the function
# arguments (see the inline comment) - confirm.
CropAutomaticGivenAncestors(inputShape, outputShape, inputAncestor, outputAncestor, tag = '') = new ComputationNode [
    operation = 'Crop';
    inputs = (inputShape:outputShape:inputAncestor:outputAncestor) /*plus the function args*/
]

# Epoch accumulator node, currently used in mean IoU criteria (see below).
# Creates an 'EpochAccumulator' computation node with a single input.
#   input - node whose values are accumulated
#   tag   - optional, defaults to ''; not referenced explicitly in the body
# NOTE(review): presumably tag reaches the node implicitly with the function
# arguments (see the inline comment) - confirm.
Accumulator(input, tag = '') = new ComputationNode [
    operation = 'EpochAccumulator';
    inputs = (input) /*plus the function args*/
]

# Hardmax along arbitrary axis.
# Returns the result of comparing z for equality with its maximum along the given axis.
#   z    - input node
#   axis - axis along which the maximum is taken (default 1)
HardMaxND(z, axis = 1) =
{
    # Maximum along the axis; compared against z below.
    maxVals = ReduceMax(z, axis = axis)
    # This could cause problems with multiple values matching maximum value:
    # every position that ties for the maximum compares equal, so the result
    # is not guaranteed to mark exactly one position per slice.
    isMax = Equal(z, maxVals)
}.isMax

# Mean intersection-over-union (IoU) error.
# The IoU score of a ground truth class is TP / (TP + FP + FN), where TP, FP and FN are
# the numbers of true positive, false positive and false negative pixels of that class
# in the entire dataset (each number is accumulated independently across minibatches).
# The mean IoU score is the IoU sum divided by classCount, and the value returned here
# is 1 minus the mean IoU score. Pixels with mask == 0 are ignored.
#   label      - ground truth node
#   out        - network output node; its maximum along axis 3 is taken as the prediction
#   mask       - multiplied into both prediction and label
#   classCount - number of classes used for averaging (default 1)
MeanIOUError(label, out, mask, classCount = 1) = {
    # Prediction indicator along axis 3; prediction and label are both gated by the mask.
    outHardmax = HardMaxND(out, axis = 3)
    labelMasked = ElementTimes(label, mask)
    outMasked = ElementTimes(outHardmax, mask)

    # Intersection: flattened, summed along axis 1 and accumulated.
    intersection = ElementTimes(outMasked, labelMasked)
    intersectionFlat = FlattenDimensions(intersection, 0, 3)
    intersectionByClass = ReduceSum(intersectionFlat, axis = 1)
    i = Accumulator(intersectionByClass)

    # Union (label + prediction - intersection): same flatten / sum / accumulate steps.
    union = (labelMasked + outMasked) - intersection
    unionFlat = FlattenDimensions(union, 0, 3)
    unionByClass = ReduceSum(unionFlat, axis = 1)
    u = Accumulator(unionByClass)

    # The small constant keeps the reciprocal finite when the accumulated union is 0.
    reciprocalUnion = Reciprocal(u + ConstantTensor(0.00001, (1)))
    iou = ElementTimes(i, reciprocalUnion)

    # Divide the IoU sum by classCount and turn the score into an error.
    iouSum = ReduceSum(iou)
    norm = Reciprocal(ConstantTensor(classCount, (1)))
    miou = ElementTimes(iouSum, norm)
    errMiou = Minus(BS.Constants.One, miou)
}.errMiou

# Pixelwise error (fraction of incorrectly classified pixels).
# Pixels with mask == 0 are ignored.
#   label - ground truth node; multiplied with the prediction and summed along axis 3
#   out   - network output node; its maximum along axis 3 is taken as the prediction
#   mask  - multiplied into the per-pixel error; its sum is the normalizer
# Returns the masked error sum divided by the mask sum.
PixelError(label, out, mask) = [
    outHardmax = HardMaxND(out, axis = 3)
    # 1 / pixel_count. The small constant (same value as in MeanIOUError) keeps the
    # reciprocal finite when the mask is all zeros, so the result is 0 rather than
    # the NaN produced by 0 * inf.
    pixelNorm = Reciprocal(ReduceSum(mask) + ConstantTensor(0.00001, (1)))
    # Product of label and prediction summed along axis 3.
    acc = ReduceSum(ElementTimes(label, outHardmax), axis = 3)
    # 1 - acc, zeroed wherever mask is 0.
    diffs = Minus(BS.Constants.One, acc) .* mask
    errSum = ReduceSum(diffs)
    err = errSum .* pixelNorm
].err

# Cross entropy with softmax along arbitrary axis, where loss value
# is normalized by the number of valid pixels in minibatch. Valid pixels
# are those with mask == 1.
#   label - ground truth node; multiplied with the shifted outputs and summed along axis
#   out   - network output node (pre-softmax values)
#   mask  - multiplied into the per-pixel loss; its sum is the normalizer
#   axis  - axis along which softmax / cross entropy is computed (default 1)
#   tag   - tag attached to the resulting loss node (default '')
CrossEntropyWithSoftmaxNDNormalized(label, out, mask, axis = 1, tag = '') = [
    # Subtract the maximum along the axis before the log-sum for numerical stability.
    out_max = ReduceMax(out, axis = axis)
    out_shift = Minus(out, out_max)
    log_sum = ReduceLogSum(out_shift, axis = axis)
    # Shifted output selected by the label along the axis.
    logits_per_class = ElementTimes(label, out_shift)
    logits = ReduceSum(logits_per_class, axis = axis)
    # Per-pixel cross entropy: log-sum minus the label-selected value.
    diff = Minus(log_sum, logits)
    diff_valid = ElementTimes(diff, mask)
    ce_unnorm = SumElements(diff_valid)
    # The small constant (same value as in MeanIOUError) keeps the reciprocal finite
    # when the mask is all zeros, so the loss is 0 rather than the NaN produced by 0 * inf.
    norm_factor = Reciprocal(SumElements(mask) + ConstantTensor(0.00001, (1)))
    ce = ElementTimes(ce_unnorm, norm_factor, tag = tag)
].ce
